feat: Add LoRA (Low-Rank Adaptation) support for efficient model fine-tuning #108
chen2021673 wants to merge 8 commits into master from
Conversation
- Add LoRA module infrastructure with configurable rank, alpha, dropout
- Implement LoRALinear wrapper for seamless integration with Linear layers
- Support tensor parallelism via LoRAParallelLinear
- Add LoRAModel utility for managing multiple LoRA layers
- Integrate LoRA configuration and utilities
- Add GPT2 example demonstrating LoRA fine-tuning
- Include comprehensive usage documentation and test suite

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Refactor LoRA config construction with proper target module parsing
- Add GetLoRAModel for in-place LoRA layer injection
- Fix DDP reducer to correctly handle LoRA parameters
- Fix RowParallel/ColumnParallel LoRA input handling to match base module behavior
- Add shape-based defensive checks for TP/SP consistency
- Move TP/SP communication helper function declarations to utils.h
- Move getter implementations from header to .cc file
- Add unit test for SaveLoRAWeights/LoadLoRAWeights functionality

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Refactor GetLoRAParameters() to retrieve only LoRA parameters for optimizer
- Add MergeAndUnload() to merge weights and export as standard model
- Update gpt2/llama3 examples to use new GetLoRAParameters API
- Refactor LoRA linear modules and fix dimension mismatch
- Improve LoRA tests and update documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```cpp
std::vector<std::shared_ptr<Tensor>> LoRAParameters() const;

// Accessors
int64_t in_features() const;
```
It looks like these accessors are only ever called from the tests?

At the moment, yes, but MergeAndUnload() calls in_features(), and they may be called from model code in the future.
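For context on why MergeAndUnload() needs the layer dimensions: merging folds the adapter back into the base weight as W' = W + (alpha / r)·BA, after which the adapter matrices can be dropped and W' saved as an ordinary Linear weight. A minimal sketch of that arithmetic (the `MergeLoRA` name and plain-vector types are illustrative, not this PR's API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// W is out x in (row-major), B is out x r, A is r x in.
// LoRA merge: W' = W + (alpha / r) * B * A.
std::vector<std::vector<float>> MergeLoRA(const std::vector<std::vector<float>> &W,
                                          const std::vector<std::vector<float>> &B,
                                          const std::vector<std::vector<float>> &A,
                                          float alpha) {
    const size_t out = W.size(), in = W[0].size(), r = A.size();
    const float scale = alpha / static_cast<float>(r);
    auto merged = W; // start from the frozen base weight
    for (size_t i = 0; i < out; ++i) {
        for (size_t j = 0; j < in; ++j) {
            float delta = 0.0f;
            for (size_t k = 0; k < r; ++k) {
                delta += B[i][k] * A[k][j];
            }
            merged[i][j] += scale * delta;
        }
    }
    return merged;
}
```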
```cpp
              * parallel::global::GetTensorParallelSize(),
          base_module->bias(), base_module->gather_output(), base_module->input_is_parallel(),
          base_module->skip_bias_add(), base_module->sequence_parallel()),
      config_(config) {
```
Noting this down for now; it may need further discussion.

After reading this inheritance-based version, I suddenly realized what feels awkward about the approach: with polymorphism, say base class A and derived class B, in most real usage you would construct B directly. Here, A and B are not constructed at the same time: the program first constructs A, and only later constructs B; B's constructor then receives an existing A instance and re-invokes A's constructor with that instance's parameters to build the base-class subobject.

The drawbacks:
- The code ends up more tangled, without saving much space;
- Polymorphism is usually motivated by the scenario where A can be a B and B is an A; here, two RowParallelLinear instances exist at once (base_module and the base-class subobject inside the derived object), and the original one is simply discarded, so behaviorally B replaces A rather than extends it.

Since B replaces A, the LoRA implementation really plays the role of a decorator, which fits composition better than inheritance.
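The pattern being criticized, and the decorator alternative, can be reduced to a toy example (all class names here are illustrative stand-ins, not the PR's real types):

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Toy stand-in for the base Linear/RowParallelLinear.
struct LinearToy {
    int in, out;
    LinearToy(int in, int out) : in(in), out(out) {}
    virtual ~LinearToy() = default;
    virtual int Forward(int x) const { return x * out; } // stands in for xW
};

// Inheritance version: the ctor takes an already-built LinearToy and
// re-invokes the base ctor with its parameters, so the base state is
// built twice and the original instance is discarded.
struct LoRALinearToy : LinearToy {
    explicit LoRALinearToy(const LinearToy &base) : LinearToy(base.in, base.out) {}
    int Forward(int x) const override { return LinearToy::Forward(x) + 1; /* + xBA */ }
};

// Composition (decorator) version: the original instance is kept, not rebuilt.
struct LoRADecoratorToy : LinearToy {
    std::shared_ptr<LinearToy> base;
    explicit LoRADecoratorToy(std::shared_ptr<LinearToy> b)
        : LinearToy(b->in, b->out), base(std::move(b)) {}
    int Forward(int x) const override { return base->Forward(x) + 1; /* + xBA */ }
};
```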
I still lean toward implementing LoRALinear with inheritance. Semantically, LoRALinear is still an implementation of Linear (y = xW + xBA), which is closer to an is-a than a has-a relationship. From the checkpoint perspective, LoRA also behaves more like extra parameters attached to a Linear than like a new module type. (https://huggingface.co/docs/peft/en/developer_guides/checkpoint)

> These LoRA matrices are implemented as nn.Linear layers, so the parameters are stored in the .weight attribute (lora_A.weight, lora_B.weight).

As for the construction-chain issue raised in the comment, I understand it mainly comes from the current implementation re-constructing the base-class subobject from base_module's parameters. If LoRALinear instead constructed its base-class part directly, rather than receiving an existing Linear and copying its state, the problem should go away (though I'm not insisting on this change).

Also, regarding the "re-invoke A's constructor with its parameters to build the base subobject" approach leaving necessary internal state of the passed-in base instance inaccessible: consider declaring the subclass a friend of the base class, so the subclass can access that internal state directly.
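The friend suggestion, sketched on toy types (illustrative only; the real base class and its internal state differ):

```cpp
#include <cassert>

// The base keeps its state private and grants the adapter subclass
// targeted access, so the subclass can be constructed from an existing
// base instance without widening the public API.
class BaseLinear {
    int weight_;               // stands in for private internal state
    friend class LoRAFromBase; // targeted access for the adapter subclass
public:
    explicit BaseLinear(int w) : weight_(w) {}
};

class LoRAFromBase : public BaseLinear {
    int lora_delta_;
public:
    // Reads base.weight_ directly thanks to the friend declaration.
    explicit LoRAFromBase(const BaseLinear &base, int delta)
        : BaseLinear(base.weight_), lora_delta_(delta) {}
    // As a friend of BaseLinear, this class may also read its own
    // inherited private weight_.
    int EffectiveWeight() const { return weight_ + lora_delta_; }
};
```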
Semantically, LoRALinear can be viewed as an extension of Linear, so inheritance is defensible. But beyond semantics, the construction approach also has to be reasonable. Right now a base module exists first, and its parameters are used to re-construct the base-class part of the derived object, because the derived object is built later to replace the existing module and cannot obtain the original constructor arguments; that is exactly what makes this look like a decorator to me. If the subclass must access large amounts of the base class's private internal state just to complete its construction, even resorting to friend, it has lost the simplicity and natural semantics of a subclass, and its relationship to the original base class becomes blurry; at that point it is cleaner to declare it as a separate class.
Semantically, implementing LoRALinear via inheritance is reasonable, but I will also grant that with injection in the picture, the current inheritance-based implementation does have redundancy in its construction path. On how LoRALinear should be injected, my suggestion is:

- First, the design is already finalized; absent a serious problem, we do not recommend changing it now;
- Second, existing implementations in the field have in fact chosen this approach, and we generally prefer to stay compatible with their designs.

So, keeping the injection design unchanged, inheritance and composition each have their own strengths and weaknesses; I will not make a hard requirement. As long as the module base class is not intruded upon, the choice is yours. Whether the injection design itself needs revising can be rediscussed when a concrete need arises.
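For the record, the injection flow both sides agree to keep can be reduced to a toy sketch: walk a module map, swap targeted entries for a LoRA-wrapped module, and let callers keep dispatching through the base pointer (the names and the `InjectLoRA` helper are illustrative, not the PR's GetLoRAModel implementation):

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct ModuleToy {
    virtual ~ModuleToy() = default;
    virtual int Forward(int x) const = 0;
};

struct LinearMod : ModuleToy {
    int w;
    explicit LinearMod(int w) : w(w) {}
    int Forward(int x) const override { return x * w; }
};

// Wraps an existing module instead of rebuilding it (composition flavor).
struct LoRAWrapped : ModuleToy {
    std::shared_ptr<ModuleToy> base;
    explicit LoRAWrapped(std::shared_ptr<ModuleToy> b) : base(std::move(b)) {}
    int Forward(int x) const override { return base->Forward(x) + x; /* + xBA */ }
};

// In-place injection: the caller's module map is edited so downstream
// code keeps using the same key and base-class interface.
void InjectLoRA(std::map<std::string, std::shared_ptr<ModuleToy>> &modules,
                const std::string &target) {
    auto it = modules.find(target);
    if (it != modules.end()) {
        it->second = std::make_shared<LoRAWrapped>(it->second);
    }
}
```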
If we cannot settle on a better solution for now, let's merge it as-is to at least cut down on duplicated work. Once the transformer model infrastructure lands, we can take another holistic look.

Summary
Added LoRA (Low-Rank Adaptation) support for parameter-efficient fine-tuning. This feature significantly reduces the number of trainable parameters through low-rank decomposition, enabling efficient fine-tuning of large models.
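The parameter reduction claimed here is easy to quantify: for a frozen weight of shape d_out × d_in, LoRA trains only the low-rank factors B (d_out × r) and A (r × d_in), i.e. r·(d_in + d_out) parameters instead of d_in·d_out. A minimal sketch (dimension names are generic, not from this PR):

```cpp
#include <cassert>
#include <cstdint>

// Full fine-tuning updates the whole d_out x d_in weight.
int64_t FullParams(int64_t d_in, int64_t d_out) { return d_in * d_out; }

// LoRA freezes that weight and trains only B (d_out x r) and A (r x d_in).
int64_t LoRAParams(int64_t d_in, int64_t d_out, int64_t r) { return r * (d_in + d_out); }
```

For a square 4096-dimension projection with r = 8, this is a 256x reduction in trainable parameters for that layer.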
Changes
New Features
LoRA Infrastructure (`infini_train/include/nn/lora/`):
- `lora_config.h/cc` - LoRA configuration (rank, alpha, dropout)
- `lora_linear.h/cc` - LoRA linear layer wrapper
- `lora_model.h/cc` - Multi-LoRA layer management
- `lora_parallel_linear.h/cc` - Tensor parallelism support
- `lora_utils.h/cc` - Utility functions

Tests: `test/lora/test_lora.cc` - Unit tests
Documentation: `docs/lora_usage.md` - Usage documentation
Examples: `example/gpt2/main.cc` - Added LoRA training example
Build: `CMakeLists.txt` - Added test_lora build target

Test Result
Accuracy:

Performance:
llama3 run result comparison: